-
Notifications
You must be signed in to change notification settings - Fork 27.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable Pipeline to get device from model #30534
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this!
Could you add a test?
@faaany are we sure that At most I see |
sure, in which test file should I put this test? |
Good point! Yes, I know that Flax model doesn't have "device". How about moving it inside Furthermore, I removed the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating and handling the torch case!
Only request is to add a test.
@muellerzr could you give a quick review as you correctly spotted and highlighted the torch vs. other frameworks case?
Hi @amyeroberts, sorry for the late response. We had a long holiday here in China. Unit tests are added. Let me explain more about in detail: There are 3 possibilities for model.device: There are 2 possibilities for pipeline.device: Sincea2&b2 is trivial, my unit tests cover the cases a1&b1, a1&b2, a3&b1 and a3&b2. Pls have a review, thx! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great - thanks for adding the tests and the explanation!
cc @muellerzr For a final double check to make sure this makes sense with accelerate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much better, thanks! Agreed post Amy's nit :)
Co-authored-by: amyeroberts <[email protected]>
Thanks for the review! @amyeroberts @muellerzr |
What does this PR do?
Currently, the code above will give an output of
But this is not OK: when users have moved the model to CUDA, Pipeline should not move the model back to CPU without showing any message. This PR makes it possible to let the model stay on its original device. Below is the results after this PR:
@Narsil and @muellerzr